{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "---\n",
        "sidebar_position: 3\n",
        "---\n",
        "\n",
        "# How to create a custom LLM class\n",
        "\n",
        "```{=mdx}\n",
        ":::info Prerequisites\n",
        "\n",
        "This guide assumes familiarity with the following concepts:\n",
        "\n",
        "- [LLMs](/docs/concepts/text_llms)\n",
        "\n",
        ":::\n",
        "```\n",
        "\n",
        "This notebook goes over how to create a custom LLM wrapper, in case you want to use your own LLM or a different wrapper than one that is directly supported in LangChain.\n",
        "\n",
        "There are a few required things that a custom LLM needs to implement after extending the [`LLM` class](https://api.js.langchain.com/classes/langchain_core.language_models_llms.LLM.html):\n",
        "\n",
        "- A `_call` method that takes in a string and call options (which includes things like `stop` sequences), and returns a string.\n",
        "- A `_llmType` method that returns a string. Used for logging purposes only.\n",
        "\n",
        "You can also implement the following optional method:\n",
        "\n",
        "- A `_streamResponseChunks` method that returns an `AsyncIterator` and yields [`GenerationChunks`](https://api.js.langchain.com/classes/langchain_core.outputs.GenerationChunk.html). This allows the LLM to support streaming outputs.\n",
        "\n",
        "Let's implement a very simple custom LLM that just echoes back the first `n` characters of the input."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [],
      "source": [
        "import { LLM, type BaseLLMParams } from \"@langchain/core/language_models/llms\";\n",
        "import type { CallbackManagerForLLMRun } from \"@langchain/core/callbacks/manager\";\n",
        "import { GenerationChunk } from \"@langchain/core/outputs\";\n",
        "\n",
        "interface CustomLLMInput extends BaseLLMParams {\n",
        "  n: number;\n",
        "}\n",
        "\n",
        "class CustomLLM extends LLM {\n",
        "  n: number;\n",
        "\n",
        "  constructor(fields: CustomLLMInput) {\n",
        "    super(fields);\n",
        "    this.n = fields.n;\n",
        "  }\n",
        "\n",
        "  _llmType() {\n",
        "    return \"custom\";\n",
        "  }\n",
        "\n",
        "  async _call(\n",
        "    prompt: string,\n",
        "    options: this[\"ParsedCallOptions\"],\n",
        "    runManager: CallbackManagerForLLMRun\n",
        "  ): Promise<string> {\n",
        "    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing\n",
        "    // await subRunnable.invoke(params, runManager?.getChild());\n",
        "    return prompt.slice(0, this.n);\n",
        "  }\n",
        "\n",
        "  async *_streamResponseChunks(\n",
        "    prompt: string,\n",
        "    options: this[\"ParsedCallOptions\"],\n",
        "    runManager?: CallbackManagerForLLMRun\n",
        "  ): AsyncGenerator<GenerationChunk> {\n",
        "    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing\n",
        "    // await subRunnable.invoke(params, runManager?.getChild());\n",
        "    for (const letter of prompt.slice(0, this.n)) {\n",
        "      yield new GenerationChunk({\n",
        "        text: letter,\n",
        "      });\n",
        "      // Trigger the appropriate callback\n",
        "      await runManager?.handleLLMNewToken(letter);\n",
        "    }\n",
        "  }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can now use this as any other LLM:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "I am\n"
          ]
        }
      ],
      "source": [
        "const llm = new CustomLLM({ n: 4 });\n",
        "\n",
        "await llm.invoke(\"I am an LLM\");"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "And support streaming:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "I\n",
            " \n",
            "a\n",
            "m\n"
          ]
        }
      ],
      "source": [
        "const stream = await llm.stream(\"I am an LLM\");\n",
        "\n",
        "for await (const chunk of stream) {\n",
        "  console.log(chunk);\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Richer outputs\n",
        "\n",
        "If you want to take advantage of LangChain's callback system for functionality like token tracking, you can extend the [`BaseLLM`](https://api.js.langchain.com/classes/langchain_core.language_models_llms.BaseLLM.html) class and implement the lower level\n",
        "`_generate` method. Rather than taking a single string as input and a single string output, it can take multiple input strings and map each to multiple string outputs.\n",
        "Additionally, it returns a `Generation` output with fields for additional metadata rather than just a string."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "import { CallbackManagerForLLMRun } from \"@langchain/core/callbacks/manager\";\n",
        "import { LLMResult } from \"@langchain/core/outputs\";\n",
        "import {\n",
        "  BaseLLM,\n",
        "  BaseLLMCallOptions,\n",
        "  BaseLLMParams,\n",
        "} from \"@langchain/core/language_models/llms\";\n",
        "\n",
        "interface AdvancedCustomLLMCallOptions extends BaseLLMCallOptions {}\n",
        "\n",
        "interface AdvancedCustomLLMParams extends BaseLLMParams {\n",
        "  n: number;\n",
        "}\n",
        "\n",
        "class AdvancedCustomLLM extends BaseLLM<AdvancedCustomLLMCallOptions> {\n",
        "  n: number;\n",
        "\n",
        "  constructor(fields: AdvancedCustomLLMParams) {\n",
        "    super(fields);\n",
        "    this.n = fields.n;\n",
        "  }\n",
        "\n",
        "  _llmType() {\n",
        "    return \"advanced_custom_llm\";\n",
        "  }\n",
        "\n",
        "  async _generate(\n",
        "    inputs: string[],\n",
        "    options: this[\"ParsedCallOptions\"],\n",
        "    runManager?: CallbackManagerForLLMRun\n",
        "  ): Promise<LLMResult> {\n",
        "    const outputs = inputs.map((input) => input.slice(0, this.n));\n",
        "    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing\n",
        "    // await subRunnable.invoke(params, runManager?.getChild());\n",
        "\n",
        "    // One input could generate multiple outputs.\n",
        "    const generations = outputs.map((output) => [\n",
        "      {\n",
        "        text: output,\n",
        "        // Optional additional metadata for the generation\n",
        "        generationInfo: { outputCount: 1 },\n",
        "      },\n",
        "    ]);\n",
        "    const tokenUsage = {\n",
        "      usedTokens: this.n,\n",
        "    };\n",
        "    return {\n",
        "      generations,\n",
        "      llmOutput: { tokenUsage },\n",
        "    };\n",
        "  }\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This will pass the additional returned information in callback events and in the `streamEvents method:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{\n",
            "  \"event\": \"on_llm_end\",\n",
            "  \"data\": {\n",
            "    \"output\": {\n",
            "      \"generations\": [\n",
            "        [\n",
            "          {\n",
            "            \"text\": \"I am\",\n",
            "            \"generationInfo\": {\n",
            "              \"outputCount\": 1\n",
            "            }\n",
            "          }\n",
            "        ]\n",
            "      ],\n",
            "      \"llmOutput\": {\n",
            "        \"tokenUsage\": {\n",
            "          \"usedTokens\": 4\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "  },\n",
            "  \"run_id\": \"a9ce50e4-f85b-41eb-bcbe-793efc52f9d8\",\n",
            "  \"name\": \"AdvancedCustomLLM\",\n",
            "  \"tags\": [],\n",
            "  \"metadata\": {}\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "const llm = new AdvancedCustomLLM({ n: 4 });\n",
        "\n",
        "const eventStream = await llm.streamEvents(\"I am an LLM\", {\n",
        "  version: \"v2\",\n",
        "});\n",
        "\n",
        "for await (const event of eventStream) {\n",
        "  if (event.event === \"on_llm_end\") {\n",
        "    console.log(JSON.stringify(event, null, 2));\n",
        "  }\n",
        "}"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "TypeScript",
      "language": "typescript",
      "name": "tslab"
    },
    "language_info": {
      "codemirror_mode": {
        "mode": "typescript",
        "name": "javascript",
        "typescript": true
      },
      "file_extension": ".ts",
      "mimetype": "text/typescript",
      "name": "typescript",
      "version": "3.7.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}